// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2024 Xi Ruoyao <xry111@xry111.site>. All Rights Reserved.
 */

#include <asm/asm.h>
#include <asm/regdef.h>
#include <linux/linkage.h>

.text

/*
 * Apply the three-operand instruction \op to four independent
 * destination/source pairs:
 *
 *	dN = dN \op sN		for N = 0, 1, 2, 3
 *
 * Each destination is also the first source operand. Whether sN is a
 * register or an immediate depends on \op; this file passes registers for
 * add.w and xor, and immediates for rotri.w.
 */
.macro	OP_4REG	op d0 d1 d2 d3 s0 s1 s2 s3
	\op	\d0, \d0, \s0
	\op	\d1, \d1, \s1
	\op	\d2, \d2, \s2
	\op	\d3, \d3, \s3
.endm

/*
 * Very basic LoongArch implementation of ChaCha20. Produces a given positive
 * number of blocks of output with a nonce of 0, taking an input key and
 * 8-byte counter. Importantly does not spill to the stack: the only stack
 * stores are the caller's s0-s9, saved in the prologue before any key
 * material has been loaded. Its arguments are:
 *
 *	a0: output bytes
 *	a1: 32-byte key input
 *	a2: 8-byte counter input/output
 *	a3: number of 64-byte blocks to write to output
 *
 * The counter is read as two 32-bit words (low word at offset 0, high word
 * at offset 4), incremented once per block, and written back after the last
 * block.
 *
 * The block count is only tested after a block has been produced, so a3
 * must not be 0.
 */
SYM_FUNC_START(__arch_chacha20_blocks_nostack)

/*
 * We don't need a frame pointer, so fp is used as one more saved register
 * under the name s9.
 */
#define s9		fp

/*
 * Register allocation. state0-state9 live in s0-s9, which are saved in the
 * prologue and restored in the epilogue; everything else lives in argument
 * and temporary registers.
 */
#define output		a0
#define key		a1
#define counter		a2
#define nblocks		a3
#define i		a4
#define state0		s0
#define state1		s1
#define state2		s2
#define state3		s3
#define state4		s4
#define state5		s5
#define state6		s6
#define state7		s7
#define state8		s8
#define state9		s9
#define state10		a5
#define state11		a6
#define state12		a7
#define state13		t0
#define state14		t1
#define state15		t2
#define cnt_lo		t3
#define cnt_hi		t4
#define copy0		t5
#define copy1		t6
#define copy2		t7

/*
 * Reuse i as copy3: i is only live inside the .Lpermute loop and copy3 is
 * only needed after it.
 */
#define copy3		i

/* Packs to be used with OP_4REG: one pack per row of the 4x4 state matrix */
#define line0		state0, state1, state2, state3
#define line1		state4, state5, state6, state7
#define line2		state8, state9, state10, state11
#define line3		state12, state13, state14, state15

/*
 * Rows 1, 2 and 3 rotated left by 1, 2 and 3 positions. Paired with line0
 * they select the diagonals of the state matrix for the even round.
 */
#define line1_perm	state5, state6, state7, state4
#define line2_perm	state10, state11, state8, state9
#define line3_perm	state15, state12, state13, state14

#define copy		copy0, copy1, copy2, copy3

/* Immediate packs for OP_4REG rotri.w: the same rotate count four times */
#define _16		16, 16, 16, 16
#define _20		20, 20, 20, 20
#define _24		24, 24, 24, 24
#define _25		25, 25, 25, 25

	/*
	 * The ABI requires s0-s9 saved, and sp aligned to 16-byte.
	 * This does not violate the stack-less requirement: no sensitive data
	 * is spilled onto the stack.
	 */
	PTR_ADDI	sp, sp, (-SZREG * 10) & STACK_ALIGN
	REG_S		s0, sp, 0
	REG_S		s1, sp, SZREG
	REG_S		s2, sp, SZREG * 2
	REG_S		s3, sp, SZREG * 3
	REG_S		s4, sp, SZREG * 4
	REG_S		s5, sp, SZREG * 5
	REG_S		s6, sp, SZREG * 6
	REG_S		s7, sp, SZREG * 7
	REG_S		s8, sp, SZREG * 8
	REG_S		s9, sp, SZREG * 9

	/*
	 * copy[0,1,2] = "expa", "nd 3", "2-by" as little-endian words. They
	 * stay constant for the whole call. The fourth word "te k" cannot be
	 * kept the same way because copy3 shares its register with i.
	 */
	li.w		copy0, 0x61707865
	li.w		copy1, 0x3320646e
	li.w		copy2, 0x79622d32

	/* Keep the 64-bit counter in registers as two 32-bit halves */
	ld.w		cnt_lo, counter, 0
	ld.w		cnt_hi, counter, 4

.Lblock:
	/* state[0,1,2,3] = "expand 32-byte k" */
	move		state0, copy0
	move		state1, copy1
	move		state2, copy2
	li.w		state3, 0x6b206574

	/* state[4,5,..,11] = key */
	ld.w		state4, key, 0
	ld.w		state5, key, 4
	ld.w		state6, key, 8
	ld.w		state7, key, 12
	ld.w		state8, key, 16
	ld.w		state9, key, 20
	ld.w		state10, key, 24
	ld.w		state11, key, 28

	/* state[12,13] = counter */
	move		state12, cnt_lo
	move		state13, cnt_hi

	/* state[14,15] = 0 */
	move		state14, zero
	move		state15, zero

	/* 10 iterations of (odd round + even round) = 20 rounds */
	li.w		i, 10
.Lpermute:
	/*
	 * odd round: the four quarter-rounds work on the columns
	 * (state[n], state[n+4], state[n+8], state[n+12]), n = 0..3.
	 * rotri.w rotates right, so ChaCha's left-rotations by 16, 12, 8
	 * and 7 bits are written as right-rotations by 16, 20, 24 and 25.
	 */
	OP_4REG	add.w	line0, line1
	OP_4REG	xor	line3, line0
	OP_4REG	rotri.w	line3, _16

	OP_4REG	add.w	line2, line3
	OP_4REG	xor	line1, line2
	OP_4REG	rotri.w	line1, _20

	OP_4REG	add.w	line0, line1
	OP_4REG	xor	line3, line0
	OP_4REG	rotri.w	line3, _24

	OP_4REG	add.w	line2, line3
	OP_4REG	xor	line1, line2
	OP_4REG	rotri.w	line1, _25

	/*
	 * even round: the same sequence, but the _perm packs make the
	 * quarter-rounds work on the diagonals (0,5,10,15), (1,6,11,12),
	 * (2,7,8,13) and (3,4,9,14). No data is moved between registers.
	 */
	OP_4REG	add.w	line0, line1_perm
	OP_4REG	xor	line3_perm, line0
	OP_4REG	rotri.w	line3_perm, _16

	OP_4REG	add.w	line2_perm, line3_perm
	OP_4REG	xor	line1_perm, line2_perm
	OP_4REG	rotri.w	line1_perm, _20

	OP_4REG	add.w	line0, line1_perm
	OP_4REG	xor	line3_perm, line0
	OP_4REG	rotri.w	line3_perm, _24

	OP_4REG	add.w	line2_perm, line3_perm
	OP_4REG	xor	line1_perm, line2_perm
	OP_4REG	rotri.w	line1_perm, _25

	addi.w		i, i, -1
	bnez		i, .Lpermute

	/*
	 * copy[3] = "te k", materialize it here because copy[3] shares the
	 * same register with i which just became dead.
	 */
	li.w		copy3, 0x6b206574

	/* output[0,1,2,3] = copy[0,1,2,3] + state[0,1,2,3] */
	OP_4REG	add.w	line0, copy
	st.w		state0, output, 0
	st.w		state1, output, 4
	st.w		state2, output, 8
	st.w		state3, output, 12

	/* from now on state[0,1,2,3] are scratch registers  */

	/*
	 * The initial key words are not kept in registers across the
	 * permutation, so they are reloaded from memory for the final add.
	 */

	/* state[0,1,2,3] = lo32(key) */
	ld.w		state0, key, 0
	ld.w		state1, key, 4
	ld.w		state2, key, 8
	ld.w		state3, key, 12

	/* output[4,5,6,7] = state[0,1,2,3] + state[4,5,6,7] */
	OP_4REG	add.w	line1, line0
	st.w		state4, output, 16
	st.w		state5, output, 20
	st.w		state6, output, 24
	st.w		state7, output, 28

	/* state[0,1,2,3] = hi32(key) */
	ld.w		state0, key, 16
	ld.w		state1, key, 20
	ld.w		state2, key, 24
	ld.w		state3, key, 28

	/* output[8,9,10,11] = state[0,1,2,3] + state[8,9,10,11] */
	OP_4REG	add.w	line2, line0
	st.w		state8, output, 32
	st.w		state9, output, 36
	st.w		state10, output, 40
	st.w		state11, output, 44

	/*
	 * output[12,13,14,15] = state[12,13,14,15] + [cnt_lo, cnt_hi, 0, 0]
	 * The initial state[14,15] were 0, so those two words need no add.
	 */
	add.w		state12, state12, cnt_lo
	add.w		state13, state13, cnt_hi
	st.w		state12, output, 48
	st.w		state13, output, 52
	st.w		state14, output, 56
	st.w		state15, output, 60

	/*
	 * ++counter
	 * state0 (scratch here) = 1 only when cnt_lo wrapped around to 0,
	 * which is the carry into cnt_hi.
	 */
	addi.w		cnt_lo, cnt_lo, 1
	sltui		state0, cnt_lo, 1
	add.w		cnt_hi, cnt_hi, state0

	/* output += 64 */
	PTR_ADDI	output, output, 64
	/* --nblocks */
	PTR_ADDI	nblocks, nblocks, -1
	bnez		nblocks, .Lblock

	/* counter = [cnt_lo, cnt_hi] */
	st.w		cnt_lo, counter, 0
	st.w		cnt_hi, counter, 4

	/*
	 * Zero out the potentially sensitive regs, in case nothing uses these
	 * again. As at now copy[0,1,2,3] just contains "expand 32-byte k" and
	 * state[0,...,9] are s0-s9 those we'll restore in the epilogue, so we
	 * only need to zero state[10,...,15].
	 */
	move		state10, zero
	move		state11, zero
	move		state12, zero
	move		state13, zero
	move		state14, zero
	move		state15, zero

	/* Restore the caller's s0-s9 and undo the prologue's sp adjustment */
	REG_L		s0, sp, 0
	REG_L		s1, sp, SZREG
	REG_L		s2, sp, SZREG * 2
	REG_L		s3, sp, SZREG * 3
	REG_L		s4, sp, SZREG * 4
	REG_L		s5, sp, SZREG * 5
	REG_L		s6, sp, SZREG * 6
	REG_L		s7, sp, SZREG * 7
	REG_L		s8, sp, SZREG * 8
	REG_L		s9, sp, SZREG * 9
	PTR_ADDI	sp, sp, -((-SZREG * 10) & STACK_ALIGN)

	jr		ra
SYM_FUNC_END(__arch_chacha20_blocks_nostack)
